module Data.Serialize.Get where
import Prelude

import Control.Alt (class Alt)
import Control.Alternative (class Alternative)
import Control.Monad.Error.Class (class MonadThrow, throwError)
import Control.MonadPlus (class MonadPlus, class MonadZero)
import Control.Plus (class Plus)
import Data.Array (drop, reverse, take, (:))
import Data.Array.Partial (init)
import Data.ByteString (Encoding(..), toString)
import Data.ByteString as B
import Data.Either (Either(..))
import Data.Foldable (fold, intercalate)
import Data.Integral (fromIntegral)
import Data.Maybe (Maybe(..), fromJust, fromMaybe)
import Data.Tuple.Nested ((/\))
import Data.UInt as U
import Data.VarInt as VI
import Data.Word (Word16(..), Word32(..), Word8(..))
import Partial.Unsafe (unsafePartial)
import Type.Quotient (runQuotient)
import Unsafe.Coerce (unsafeCoerce)

-- | The outcome of running a parser.
-- |
-- | * `Fail msg rest`   - parsing failed with message `msg`; `rest` is the
-- |                       input handed to the failure continuation.
-- | * `Partial k`       - the parser needs another chunk of input; feed it
-- |                       to `k` to continue.
-- | * `Done value rest` - parsing succeeded with `value`; `rest` is the
-- |                       input that was left over.
data Result r = Fail String B.ByteString
              | Partial (B.ByteString -> Result r)
              | Done r B.ByteString
             
-- | Debug rendering of a `Result`. The leftover input of a `Fail` and the
-- | continuation of a `Partial` are not shown.
instance showResult :: Show r => Show (Result r) where
    show result = case result of
        Fail reason _      -> "Fail " <> show reason
        Partial _          -> "Partial _"
        Done value leftover -> "Done " <> show value <> " " <> show leftover

-- | Maps over the value of a `Done`. A `Partial` is mapped by transforming
-- | whatever result its continuation eventually produces; a `Fail` is
-- | passed through unchanged.
instance functorResult :: Functor Result where
    map f = case _ of
        Fail reason leftover -> Fail reason leftover
        Partial continue     -> Partial (\chunk -> map f (continue chunk))
        Done value leftover  -> Done (f value) leftover

-- | The input currently available to a parser.
type Input  = B.ByteString
-- | Chunks recorded so that input can be rebuilt when an alternative fails.
-- | `runGet` starts with `Nothing`; `alt` starts its first branch with
-- | `Just B.empty` (see `emptyBuffer`). `extendBuffer` leaves `Nothing`
-- | untouched, so `Nothing` effectively means "not recording".
type Buffer = Maybe B.ByteString

-- | Whether further input chunks may still be supplied.
-- |
-- | * `Complete`     - no more input; `ensure` fails when it runs short.
-- | * `Incomplete n` - more input may follow; `ensure` answers with a
-- |                    `Partial`. The optional `Int` is decremented by the
-- |                    length of each chunk received and is what
-- |                    `moreLength` reports (presumably the number of bytes
-- |                    still expected -- confirm with callers).
data More
  = Complete
  | Incomplete (Maybe Int)

derive instance eqMore :: Eq More

-- | Failure continuation: receives the input, buffer and `More` state at the
-- | point of failure, the stack of labels (see `label`), and the message.
type Failure   r = Input -> Buffer -> More -> Array String -> String -> Result r
-- | Success continuation: receives the remaining input, buffer and `More`
-- | state, the running count of bytes read (see `bytesRead`), and the value.
type Success a r = Input -> Buffer -> More -> Int -> a -> Result r



-- | A parser written in continuation-passing style.
-- |
-- | `unGet` takes the current input, the backtracking buffer, the `More`
-- | state, the count of bytes read so far, a failure continuation and a
-- | success continuation, and produces the final `Result`.
newtype Get a = Get
  { unGet :: forall r. Input -> Buffer -> More
                       -> Int -> Failure r -> Success a r -> Result r }

-- | A buffer that is recording but has not stored any bytes yet.
emptyBuffer :: Buffer
emptyBuffer = pure B.empty

-- | Append `chunk` to a recording buffer. A `Nothing` buffer stays `Nothing`.
extendBuffer :: Buffer -> B.ByteString -> Buffer
extendBuffer buf chunk = map (_ <> chunk) buf

-- | Concatenate two buffers. The result is `Nothing` as soon as either side
-- | is `Nothing` (unlike `<>` on `Maybe`, which treats `Nothing` as identity).
append :: Buffer -> Buffer -> Buffer
append l r = do
  left  <- l
  right <- r
  pure (left <> right)

-- | The bytes held by a buffer, or the empty byte string for `Nothing`.
bufferBytes :: Buffer -> B.ByteString
bufferBytes buf = case buf of
  Just bytes -> bytes
  Nothing    -> B.empty

-- | The count carried by an `Incomplete` state, or 0 when there is none
-- | (including `Complete`).
moreLength :: More -> Int
moreLength Complete              = 0
moreLength (Incomplete Nothing)  = 0
moreLength (Incomplete (Just n)) = n

-- | Applies a function to the value a parser hands to its success
-- | continuation; parser state is passed through untouched.
instance getFunctor :: Functor Get where
    map f (Get parser) = Get
      { unGet: \input buf more pos onFail onOk ->
          parser.unGet input buf more pos onFail
            (\input' buf' more' pos' value -> onOk input' buf' more' pos' (f value))
      }

-- | Runs the function parser, then the argument parser on the state the
-- | first one left behind, and applies the one result to the other. Both
-- | share the same failure continuation.
instance applyGet :: Apply Get where
  apply (Get pf) (Get px) = Get
    { unGet: \i0 buf0 more0 n0 onFail onOk ->
        pf.unGet i0 buf0 more0 n0 onFail
          (\i1 buf1 more1 n1 fn ->
            px.unGet i1 buf1 more1 n1 onFail
              (\i2 buf2 more2 n2 arg -> onOk i2 buf2 more2 n2 (fn arg)))
    }

-- | `pure` succeeds immediately with the given value and consumes nothing.
instance applicativeGet :: Applicative Get where
  pure value = Get { unGet: \input buf more pos _ onOk -> onOk input buf more pos value }

-- | Backtracking choice: run `a`; if it fails, rewind and run `b`.
-- |
-- | `a` runs with a fresh recording buffer (`emptyBuffer`) so that every
-- | chunk it pulls in is remembered. If `a` fails, the input for `b` is
-- | rebuilt as the original input `s0` plus everything recorded.
-- |
-- | The outer buffer `b0` is merged back with the lifted `append`, not with
-- | `<>` on `Maybe`. `Maybe`'s `Semigroup` treats `Nothing` as identity, so
-- | `<>` would turn a disabled outer buffer (`Nothing`, as passed by
-- | `runGet`) into `Just ...`. Recording would then stay switched on after
-- | the alternative finishes, and `failK` would append the recorded bytes
-- | to the leftover input.
instance altGet :: Alt Get where
   alt (Get a) (Get b) =
      Get {unGet:(\s0 b0 m0 w0 kf ks ->
        let -- Success in either branch: restore the outer buffer, extended
            -- by whatever was recorded meanwhile (if it was recording).
            ks' s1 b1        = ks s1 (append b0 b1)
            -- Failure of the second branch: report against the rebuilt input.
            kf' _  b1 m1     = kf (s0 <> bufferBytes b1)
                                  (append b0 b1) m1
            -- Failure of the first branch: discard its labels and message,
            -- and run `b` from the original position (`w0`) on the rebuilt
            -- input.
            try _  b1 m1 _ _ = b.unGet (s0 <> bufferBytes b1)
                                       b1 m1 w0 kf' ks'
        in a.unGet  s0 emptyBuffer m0 w0 try ks')}

-- | `empty` is the parser that always fails, with the description "mzero"
-- | (which `failDesc` prefixes with "Failed reading: ").
instance plusGet :: Plus Get where
 empty = failDesc "mzero"

-- | Declared without members of its own.
instance alternativeGet :: Alternative Get

-- | Sequencing: run the first parser, feed its value to `next`, and run the
-- | resulting parser on the state the first one left behind. Both share the
-- | same failure continuation.
instance bindGet :: Bind Get where
   bind (Get first) next = Get
     { unGet: \i0 buf0 more0 n0 onFail onOk ->
         first.unGet i0 buf0 more0 n0 onFail
           (\i1 buf1 more1 n1 value -> case next value of
               Get second -> second.unGet i1 buf1 more1 n1 onFail onOk)
     }

-- These instances are declared without members of their own.
instance monadGet :: Monad Get
instance monadZeroGet :: MonadZero Get
instance monadPlusGet :: MonadPlus Get

-- | Throwing a `String` fails the parser via `failDesc`, so the message is
-- | prefixed with "Failed reading: ".
instance monadThrowGet :: MonadThrow String Get where
  throwError = failDesc

--------------------------------------------------
-- | Render the label stack attached to a failure: one label per line, each
-- | indented with a tab, or a fixed notice when no labels were recorded.
formatTrace :: Array String -> String
formatTrace labels = case labels of
  [] -> "Empty call stack"
  _  -> "From:\t" <> intercalate "\n\t" labels <> "\n"

-- | Yield the current input without consuming any of it.
get :: Get B.ByteString
get = Get { unGet: \input buf more pos _ onOk -> onOk input buf more pos input }

-- | Replace the current input with `s` and the bytes-read counter with `w`.
-- | The buffer and `More` state are left as they were.
put :: B.ByteString -> Int -> Get Unit
put s w = Get { unGet: \_ buf more _ _ onOk -> onOk s buf more w unit }

-- | Run a parser with `l` pushed onto the front of the label stack that is
-- | reported if the parser fails. Success is unaffected.
label :: forall a. String -> Get a -> Get a
label l (Get parser) = Get
  { unGet: \input buf more pos onFail onOk ->
      parser.unGet input buf more pos
        (\input' buf' more' trace -> onFail input' buf' more' (l : trace))
        onOk
  }

-- | Top-level success continuation: wrap the value and the leftover input
-- | in `Done`, ignoring the buffer, `More` state and byte counter.
finalK :: forall a. Success a a
finalK = \leftover _ _ _ value -> Done value leftover

-- | Top-level failure continuation: build a `Fail` whose message is the
-- | failure message followed, on a new line, by the formatted label trace.
-- | The leftover bytes are the input at the point of failure followed by
-- | the contents of the buffer.
-- |
-- | The two message parts are joined with a newline; concatenating them
-- | directly ran the text together, e.g. "...too few bytesFrom:\t...".
failK ::forall a. Failure a
failK s b _ ls msg = Fail (intercalate "\n" [msg, formatTrace ls]) (s <> bufferBytes b)

-- | Run a parser over a complete byte string.
-- |
-- | The parser starts with no backtracking buffer, `Complete` input and a
-- | byte counter of 0. Returns `Right` with the parsed value, or `Left`
-- | with the failure message. A `Partial` result is not expected with
-- | `Complete` input and is reported as an internal error.
runGet :: forall a. Get a -> B.ByteString -> Either String a
runGet (Get parser) str = toEither (parser.unGet str Nothing Complete 0 failK finalK)
  where
  toEither = case _ of
    Done value _  -> Right value
    Fail reason _ -> Left reason
    Partial _     -> Left "Failed reading: Internal error: unexpected Partial."

-- | A parser that always fails with "Failed reading: " followed by `err`,
-- | reporting an empty label stack and the state it was given.
failDesc :: forall a. String -> Get a
failDesc err = Get { unGet: \input buf more _ onFail _ -> onFail input buf more [] message }
  where
  message = "Failed reading: " <> err

-----------------------------------------------------------
-- | Succeed once at least `n0` bytes of input are available, yielding the
-- | whole available input (which may be longer than `n0`) without
-- | consuming any of it.
-- |
-- | When the current input is too short the outcome depends on `More`:
-- | with `Complete` the parser fails with "too few bytes"; with
-- | `Incomplete` it returns a `Partial` that waits for the next chunk and
-- | fails the same way if that chunk is empty.
ensure ::Int -> Get B.ByteString
-- NOTE(review): `unsafeCoerce` switches off checking of `unGetFn` against the
-- `forall r.` field type -- presumably a workaround for a rank-N inference
-- problem (see the commented-out signatures below). Confirm before removing.
ensure n0 = Get $ {unGet:unsafeCoerce $ unGetFn}
 where
   --unGetFn::forall r.  B.ByteString  -> Maybe B.ByteString -> More -> Int ->  Failure r -> Success B.ByteString r -> Result r
   unGetFn s0 b0 m0 w0 kf ks = let
    n' = n0 - B.length s0
    in if n' <= 0
        then ks s0 b0 m0 w0 s0
        else getMore n' s0 [] b0 m0 w0 kf ks
   -- `ss` holds earlier chunks, most recent first, and `s0` is the newest
   -- one, so reversing `s0 : ss` restores arrival order.
   finalInput s0 ss = fold (reverse (s0 : ss))
   -- `init` drops the last element of `s0 : ss`, i.e. the oldest chunk (the
   -- input that was already present); only newly received chunks are added
   -- to the buffer. `s0 : ss` is never empty, so `init` is safe.
   finalBuffer b0 s0 ss = extendBuffer b0 (fold (reverse (unsafePartial $ init (s0 : ss))))
   --getMore::forall r.Int -> B.ByteString -> Array B.ByteString -> Maybe B.ByteString -> More -> Int -> Failure r -> Success B.ByteString r -> Result r
   -- `n` is the number of bytes still missing.
   getMore n s0 ss b0 m0 w0 kf ks = let
            -- Deferred behind a dummy argument: `let` bindings are evaluated
            -- eagerly, so a plain value here would invoke the failure
            -- continuation on every call, even when a `Partial` is returned
            -- and the failure is never used.
            tooFewBytes _ = let
                 s = finalInput s0 ss
                 b = finalBuffer b0 s0 ss
                 in kf s b m0 ["demandInput"] "too few bytes"
            in case m0 of
                Complete -> tooFewBytes unit
                Incomplete mb -> Partial $ \s ->
                    if ((B.length s) == 0)
                        then tooFewBytes unit
                        else let
                            -- Count the received chunk against the optional
                            -- number carried by `Incomplete`.
                            mb' = case mb of
                                Just l -> Just $ l - B.length s
                                Nothing -> Nothing
                            in checkIfEnough n s (s0 : ss) b0 (Incomplete mb') w0 kf ks
   -- `s0` is the chunk just received, `ss` everything before it.
   checkIfEnough n s0 ss b0 m0 w0 kf ks = let
         n' = n - B.length s0
            in if n' <= 0
                then let
                    s = finalInput s0 ss
                    b = finalBuffer b0 s0 ss
                    in ks s b m0 w0 s
                else getMore n' s0 ss b0 m0 w0 kf ks


---------------------------------------------------------
-- | Consume exactly `n` bytes and return them, advancing the bytes-read
-- | counter by `n`. Fails if `n` is negative or if `ensure n` fails.
getBytes :: Int -> Get B.ByteString
getBytes n
  | n < 0 = throwError "getBytes: negative length requested"
  | otherwise = do
      available <- ensure n
      consumed  <- bytesRead
      put (dropBS n available) (consumed + n)
      pure (takeBS n available)

-- | The first `n` bytes of a byte string, computed by unpacking to an array
-- | and packing the prefix again.
takeBS :: Int -> B.ByteString -> B.ByteString
takeBS n = B.pack <<< take n <<< B.unpack
-- | A byte string without its first `n` bytes, computed by unpacking to an
-- | array and packing the remainder again.
dropBS :: Int -> B.ByteString -> B.ByteString
dropBS n = B.pack <<< drop n <<< B.unpack

-- | Yield the current value of the bytes-read counter; nothing is consumed.
bytesRead :: Get Int
bytesRead = Get { unGet: \input buf more pos _ onOk -> onOk input buf more pos pos }

-- | Decode a variable-length integer from the front of the current input
-- | and advance past it.
-- |
-- | `VI.decode` is given the whole current input and offset 0; its second
-- | result is used as the number of bytes to skip (presumably the encoded
-- | width -- confirm against `Data.VarInt`).
-- | NOTE(review): no `ensure` is performed first, so behaviour on empty or
-- | truncated input depends entirely on `VI.decode` -- verify.
getVarInt :: Get Int
getVarInt = do
   input <- get
   consumed <- bytesRead
   let decoded /\ width = VI.decode (B.unsafeThaw input) 0
   put (dropBS width input) (consumed + width)
   pure decoded

------------------------------------------------------------
-- | Convert a byte-string octet to a `Word8`.
otoW8 :: B.Octet -> Word8
otoW8 = fromIntegral <<< runQuotient


-- | Consume one byte and return it as a `Word8`.
getWord8 :: Get Word8
getWord8 = map firstOctet (getBytes 1)
  where
  -- `getBytes 1` yields exactly one byte, so `B.head` cannot be `Nothing`.
  firstOctet bytes = otoW8 (unsafePartial (fromJust (B.head bytes)))

-- | Consume two bytes and decode them as a big-endian `Word16`.
getWord16be :: Get Word16
getWord16be = word16be <$> getBytes 2

-- | Decode the first two bytes of `bs` as a big-endian `Word16`: byte 0 is
-- | the high octet, byte 1 the low octet. Uses `B.unsafeIndex`, so the
-- | caller must supply at least two bytes.
word16be :: B.ByteString -> Word16
word16be bs = Word16 (U.shl (octet 0) (U.fromInt 8) `U.or` octet 1)
  where
  -- The byte at index `i`, unwrapped to the underlying unsigned integer.
  octet i = case otoW8 (B.unsafeIndex bs i) of
    Word8 u -> u

-- | Consume four bytes and decode them as a big-endian `Word32`.
getWord32be :: Get Word32
getWord32be = word32be <$> getBytes 4

-- | Decode the first four bytes of `bs` as a big-endian `Word32`: byte 0 is
-- | the most significant octet, byte 3 the least. Uses `B.unsafeIndex`, so
-- | the caller must supply at least four bytes.
word32be :: B.ByteString -> Word32
word32be bs =
  Word32 (shifted 0 24 `U.or` shifted 1 16 `U.or` shifted 2 8 `U.or` octet 3)
  where
  -- The byte at index `i`, unwrapped to the underlying unsigned integer.
  octet i = case otoW8 (B.unsafeIndex bs i) of
    Word8 u -> u
  -- The byte at index `i`, moved `bits` positions to the left.
  shifted i bits = U.shl (octet i) (U.fromInt bits)


-- | Run `m` repeatedly until the current input is empty, returning the
-- | results in the order they were parsed.
-- |
-- | `m` is always run at least once, before the input is inspected, so the
-- | result is never empty and empty input is handled by `m` itself.
getListOf :: forall a. Get a -> Get (Array a)
getListOf m = collect []
  where
  -- `acc` holds the items parsed so far, most recent first.
  collect acc = do
    item <- m
    leftover <- get
    let acc' = item : acc
    case B.length leftover of
      0 -> pure (reverse acc')
      _ -> collect acc'
    
-- | The length of the current input plus the count reported by
-- | `moreLength` for the current `More` state. Nothing is consumed.
remaining :: Get Int
remaining = Get
  { unGet: \input buf more pos _ onOk ->
      onOk input buf more pos (B.length input + moreLength more)
  }

-- | Consume as many bytes as `remaining` reports and return them.
getRemainByteString :: Get B.ByteString
getRemainByteString = remaining >>= getBytes

-- | Consume the rest of the input (see `getRemainByteString`) and decode it
-- | as UTF-8 text.
getRemainString :: Get String
getRemainString = map (flip toString UTF8) getRemainByteString